{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 决赛数据预处理\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 环境 & 配置\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Loading model cost 0.926 seconds.\n",
      "Prefix dict has been built succesfully.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import re\n",
    "from tqdm import tqdm\n",
    "import traceback\n",
    "import random\n",
    "import sys\n",
    "import pprint\n",
    "import jieba\n",
    "\n",
    "sys.path.insert(0, \"/home/team55/notespace/zengbin\")\n",
    "\n",
    "from jddc.config import PreConfig\n",
    "from jddc.utils import write_file, read_file, save_to_pkl, read_from_pkl, create_logger\n",
    "from jddc.obj import Session, Sentence\n",
    "from jddc.seg import JiebaSeg, jieba_tokenize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "conf = PreConfig()\n",
    "logger = create_logger(name='pre', log_file=conf.log_file, cmd=conf.cmd_log)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1 - 基础处理\n",
    "---\n",
    "\n",
    "\n",
    "**处理内容**\n",
    "\n",
    "1. 读入chat.txt，对超过7个字段的数据行进行处理，整理成7个字段（将第7个字段之后的所有字段与第7个字段合并），结果文件 chat_pred.txt\n",
    "2. 读入chat_pred.txt，根据session_id划分对话，将每一个对话所有行归集，结果文件 chat_splited.txt\n",
    "3. 读入chat_splited.txt，解析每一个会话，合并连续q、a，提取订单信息等，结果文件 session_parsed.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### a - 合并多余字段"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def merge_surplus_col(input_file, output_file):\n",
    "    \"\"\"Normalise every row of chat.txt to at most 7 tab-separated fields.\n",
    "\n",
    "    Rows with more than 7 fields have field 7 and everything after it\n",
    "    joined with a full-width comma into a single 7th field. Rows with 7\n",
    "    or fewer fields are kept as they are, minus their line terminator.\n",
    "\n",
    "    Args:\n",
    "        input_file: path of the raw chat file; opened as utf-8-sig, so a\n",
    "            leading BOM is dropped.\n",
    "        output_file: path handed to write_file (chat_pred.txt in this\n",
    "            pipeline).\n",
    "\n",
    "    Returns:\n",
    "        list of str: the normalised rows, none ending in a line terminator.\n",
    "    \"\"\"\n",
    "    with open(input_file, 'r', encoding=\"utf-8-sig\") as f:\n",
    "        raw_lines = f.readlines()\n",
    "\n",
    "    lines = []\n",
    "    # Iterating the list itself (not enumerate) lets tqdm show a total.\n",
    "    for raw in tqdm(raw_lines, desc=\"merge_surplus_col\", ncols=100):\n",
    "        # Strip the terminator from EVERY row. Previously only rows with\n",
    "        # surplus fields were rewritten, so the result mixed rows with and\n",
    "        # without a trailing newline.\n",
    "        line = raw.strip(\"\\r\\n\")\n",
    "        cols = line.split('\\t')\n",
    "        if len(cols) > 7:\n",
    "            # Fold the surplus fields back into field 7.\n",
    "            line = '\\t'.join(cols[:6] + [\"，\".join(cols[6:])])\n",
    "        lines.append(line)\n",
    "\n",
    "    # NOTE(review): assumes write_file terminates each item itself, as the\n",
    "    # other callers in this notebook pass newline-free strings - confirm.\n",
    "    write_file(output_file, lines, mode='w')\n",
    "    return lines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lines = merge_surplus_col(conf.file_chat, conf.file_chat_pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### b - 将数据集按session拆分"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def chat_split_by_session(input_file, output_file):\n",
    "    \"\"\"Group the rows of chat_pred.txt into sessions by session id.\n",
    "\n",
    "    Consecutive rows sharing the same first field (the session id) are\n",
    "    collected into a dict with the keys \"session_id\" and \"lines\". Each\n",
    "    dict is serialised with str() and the list is written via write_file\n",
    "    (chat_splited.txt in this pipeline).\n",
    "\n",
    "    Rows that do not have exactly 7 tab-separated fields are logged and\n",
    "    skipped.\n",
    "\n",
    "    Args:\n",
    "        input_file: path read with read_file.\n",
    "        output_file: path handed to write_file.\n",
    "\n",
    "    Returns:\n",
    "        list of str: one str(dict) per session, in file order. Empty when\n",
    "        the input has no valid rows.\n",
    "    \"\"\"\n",
    "    lines = read_file(input_file)\n",
    "\n",
    "    chat_splited = []\n",
    "    # Session being accumulated; stays None until the first valid row, so an\n",
    "    # empty input or a malformed first row cannot create an empty session.\n",
    "    sess_info = None\n",
    "\n",
    "    for line in tqdm(lines, desc=\"chat_split_by_session\", ncols=100):\n",
    "        try:\n",
    "            cols = line.split(\"\\t\")\n",
    "            # Validate before touching the fields so short and long rows are\n",
    "            # reported the same way. An explicit raise is used because\n",
    "            # assert statements are removed under python -O.\n",
    "            if len(cols) != 7:\n",
    "                raise ValueError(\"总共有七个字段，当前行有%i个字段\" % len(cols))\n",
    "            session_id = cols[0]\n",
    "            if sess_info is not None and sess_info['session_id'] == session_id:\n",
    "                sess_info['lines'].append(line)\n",
    "            else:\n",
    "                # A new session id starts: store the finished session first.\n",
    "                if sess_info is not None:\n",
    "                    chat_splited.append(sess_info)\n",
    "                sess_info = {\n",
    "                    \"session_id\": session_id,\n",
    "                    \"lines\": [line]\n",
    "                }\n",
    "        except Exception as e:\n",
    "            logger.error('line error: %s' % line)\n",
    "            logger.exception(e)\n",
    "\n",
    "    # Store the last session after the loop. Doing this inside the try block\n",
    "    # lost the session whenever the final row of the file was malformed.\n",
    "    if sess_info is not None:\n",
    "        chat_splited.append(sess_info)\n",
    "\n",
    "    # NOTE(review): the str(dict) form is read back with eval() by\n",
    "    # chat_session_parse; kept unchanged for compatibility with that reader.\n",
    "    chat_splited = [str(x) for x in chat_splited]\n",
    "    write_file(output_file, chat_splited, mode='w')\n",
    "    return chat_splited"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "chat_splited = chat_split_by_session(conf.file_chat_pred, conf.file_chat_splited)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### c - 会话解析"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _parse_session(sess_info):\n",
    "    \"\"\"Parse one session: collect its metadata and merge consecutive messages.\n",
    "\n",
    "    Args:\n",
    "        sess_info: dict with the keys \"session_id\" and \"lines\". Every entry\n",
    "            of \"lines\" is a tab-separated row whose fields are used here as\n",
    "            1=user id, 2=sender flag, 3=transfer, 4=repeat, 5=sku,\n",
    "            6=message text.\n",
    "\n",
    "    Returns:\n",
    "        dict with the keys:\n",
    "            \"session_id\": copied from sess_info\n",
    "            \"user_id\": field 1 of the first row\n",
    "            \"order_id\": distinct tokens of the form ORDERID_ plus 8 digits\n",
    "                found in the message texts\n",
    "            \"sku\", \"transfer\", \"repeat\": distinct non-empty values of the\n",
    "                respective field\n",
    "            \"lines\": the original rows\n",
    "            \"qas_merged\": list of (sender, text) tuples in which consecutive\n",
    "                rows of the same sender are joined with a tab\n",
    "        The distinct-value lists are in first-seen order.\n",
    "\n",
    "    Raises:\n",
    "        IndexError: if \"lines\" is empty or a row has fewer than 7 fields.\n",
    "    \"\"\"\n",
    "    lines = sess_info['lines']\n",
    "    # Split each row once instead of re-splitting it for every field.\n",
    "    rows = [line.split(\"\\t\") for line in lines]\n",
    "    user_id = rows[0][1]\n",
    "\n",
    "    def _distinct(values):\n",
    "        # Non-empty values without duplicates, in first-seen order;\n",
    "        # list(set(...)) returned them in arbitrary order.\n",
    "        found = []\n",
    "        for value in values:\n",
    "            if value != '' and value not in found:\n",
    "                found.append(value)\n",
    "        return found\n",
    "\n",
    "    transfer = _distinct(row[3] for row in rows)\n",
    "    repeat = _distinct(row[4] for row in rows)\n",
    "    sku = _distinct(row[5] for row in rows)\n",
    "\n",
    "    # Extract the order ids from the concatenated message texts.\n",
    "    contents = \"\\t\".join(row[6] for row in rows)\n",
    "    order_id = _distinct(re.findall(r'(ORDERID_\\d{8})', contents))\n",
    "\n",
    "    # Merge consecutive rows of the same sender. Extending the last merged\n",
    "    # entry needs no special handling for the final row and also keeps the\n",
    "    # message of a single-row session, which used to be dropped.\n",
    "    qas_merged = []\n",
    "    for row in rows:\n",
    "        sender, text = row[2], row[6]\n",
    "        if qas_merged and qas_merged[-1][0] == sender:\n",
    "            qas_merged[-1] = (sender, qas_merged[-1][1] + \"\\t\" + text)\n",
    "        else:\n",
    "            qas_merged.append((sender, text))\n",
    "\n",
    "    return {\n",
    "        \"session_id\": sess_info['session_id'],\n",
    "        \"user_id\": user_id,\n",
    "        \"order_id\": order_id,\n",
    "        \"sku\": sku,\n",
    "        \"transfer\": transfer,\n",
    "        \"repeat\": repeat,\n",
    "        \"lines\": lines,\n",
    "        \"qas_merged\": qas_merged\n",
    "    }\n",
    "\n",
    "def chat_session_parse(input_file, output_file):\n",
    "    \"\"\"Parse every session of chat_splited.txt and write the results.\n",
    "\n",
    "    Each input row is the str() of a session dict as written by\n",
    "    chat_split_by_session. Rows that cannot be evaluated or parsed are\n",
    "    logged and skipped. The parsed dicts are serialised with str() and\n",
    "    written via write_file (session_parsed.txt in this pipeline).\n",
    "\n",
    "    Args:\n",
    "        input_file: path read with read_file.\n",
    "        output_file: path handed to write_file.\n",
    "\n",
    "    Returns:\n",
    "        list of str: one str(dict) per successfully parsed session.\n",
    "    \"\"\"\n",
    "    chat_splited = read_file(input_file)\n",
    "    session_parsed = []\n",
    "    for raw in tqdm(chat_splited, desc='chat_session_parse', ncols=100):\n",
    "        try:\n",
    "            # SECURITY: eval() executes whatever the file contains. Only run\n",
    "            # this on files produced by this notebook; ast.literal_eval is\n",
    "            # the safe alternative for plain dict literals.\n",
    "            # Evaluating inside the try block keeps one corrupt row from\n",
    "            # aborting the whole run.\n",
    "            sess_info = eval(raw)\n",
    "            session_parsed.append(str(_parse_session(sess_info)))\n",
    "        except Exception as e:\n",
    "            # Same reporting style as chat_split_by_session.\n",
    "            logger.error('session error: %s' % raw)\n",
    "            logger.exception(e)\n",
    "    write_file(output_file, session_parsed, mode='w')\n",
    "    return session_parsed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "session_parsed = chat_session_parse(conf.file_chat_splited, conf.file_session_parsed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "sessions = read_file(conf.file_session_parsed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/home/team55/notespace/data/temp/all_sessions.pkl'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "conf.pkl_sessions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s = Session(eval(sessions[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s.multi_qa"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_pkl_sessions(sessions):\n",
    "    \"\"\"Build Session objects from serialised session dicts and pickle them.\n",
    "\n",
    "    Each item of sessions is evaluated with eval() and wrapped in Session;\n",
    "    only objects whose data_quality attribute is truthy are kept. The kept\n",
    "    objects are saved to conf.pkl_sessions with save_to_pkl.\n",
    "\n",
    "    Args:\n",
    "        sessions: iterable of str, each the str() form of a parsed session.\n",
    "    \"\"\"\n",
    "    target = conf.pkl_sessions\n",
    "    progress = tqdm(sessions, ncols=100, desc=\"create sess_objs\")\n",
    "    # NOTE(review): eval() executes the row content; only feed it rows that\n",
    "    # were written by this notebook.\n",
    "    built = (Session(eval(raw)) for raw in progress)\n",
    "    kept = [obj for obj in built if obj.data_quality]\n",
    "    print(\"save new sessions to %s\" % target)\n",
    "    save_to_pkl(target, data=kept)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "create sess_objs: 100%|█████████████████████████████████| 1025140/1025140 [04:19<00:00, 3947.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save new sessions to /home/team55/notespace/data/temp/all_sessions.pkl\n"
     ]
    }
   ],
   "source": [
    "create_pkl_sessions(sessions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_sessions():\n",
    "    \"\"\"Return the list of Session objects, using the pickle cache if present.\n",
    "\n",
    "    When conf.pkl_sessions exists it is loaded with read_from_pkl.\n",
    "    Otherwise the objects are rebuilt from conf.file_session_parsed\n",
    "    (keeping only those with a truthy data_quality attribute), saved to\n",
    "    conf.pkl_sessions and returned.\n",
    "    \"\"\"\n",
    "    cache_path = conf.pkl_sessions\n",
    "\n",
    "    # Fast path: the cache already exists.\n",
    "    if os.path.exists(cache_path):\n",
    "        print(\"load sessions from %s\" % cache_path)\n",
    "        return read_from_pkl(cache_path)\n",
    "\n",
    "    # Slow path: rebuild from the parsed-session file and refresh the cache.\n",
    "    print(\"refresh sessions ...\")\n",
    "    raw_sessions = read_file(conf.file_session_parsed)\n",
    "    progress = tqdm(raw_sessions, ncols=100, desc=\"create sess_objs\")\n",
    "    # NOTE(review): eval() executes the row content; only feed it rows that\n",
    "    # were written by this notebook.\n",
    "    built = (Session(eval(raw)) for raw in progress)\n",
    "    sess_objs = [obj for obj in built if obj.data_quality]\n",
    "    print(\"save new sessions to %s\" % cache_path)\n",
    "    save_to_pkl(cache_path, data=sess_objs)\n",
    "    return sess_objs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load sessions from /home/team55/notespace/data/temp/all_sessions.pkl\n"
     ]
    }
   ],
   "source": [
    "sessions = load_sessions()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "989930"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(sessions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### d - 查找所有脱敏词\n",
    "---\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the text part (index 1) of every merged message of every session.\n",
    "texts = []\n",
    "for sess in tqdm(sessions, ncols=100, desc=\"find desensitization\"):\n",
    "    texts.extend(pair[1] for pair in sess.qas_merged)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw strings keep the regex backslashes out of Python string-escape\n",
    "# processing (a bare backslash-bracket is an invalid escape sequence).\n",
    "# pat1: a '#', the shortest run of characters up to a '[', then the\n",
    "# shortest run up to the closing ']'.\n",
    "pat1 = re.compile(r\"(#.*?\\[.*?\\])\")\n",
    "# pat2: any bracketed group, shortest match.\n",
    "pat2 = re.compile(r\"(\\[.*?\\])\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_pat1 = pat1.findall(\" \".join(texts))\n",
    "res_pat2 = pat2.findall(\" \".join(texts))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2 - 数据清洗 & 拆分\n",
    "\n",
    "---\n",
    "\n",
    "1. 根据qaqaq的长度进行清洗，仅保留长度在(30, 500)范围内的对话"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### a - 选取1000个session进行开发测试\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "s_1000 = random.sample(sessions, 1000)\n",
    "save_to_pkl(conf.pkl_mqa_1000, s_1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "s_10000 = random.sample(sessions, 10000)\n",
    "save_to_pkl(conf.pkl_mqa_10000, s_10000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "s_100000 = random.sample(sessions, 100000)\n",
    "# This sample used to be saved to conf.pkl_mqa_10000, overwriting the\n",
    "# 10000-session sample written by the previous cell.\n",
    "# NOTE(review): PreConfig is not visible here. If it has no pkl_mqa_100000\n",
    "# attribute, a sibling file of pkl_mqa_10000 is used - confirm the intended\n",
    "# location.\n",
    "pkl_mqa_100000 = getattr(conf, \"pkl_mqa_100000\", None)\n",
    "if pkl_mqa_100000 is None:\n",
    "    pkl_mqa_100000 = os.path.join(os.path.dirname(conf.pkl_mqa_10000), \"mqa_100000.pkl\")\n",
    "save_to_pkl(pkl_mqa_100000, s_100000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3 - 构造用于训练词向量的数据集\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_sentences_for_embedding():\n",
    "    \"\"\"Write the segmented text of every merged message for embedding training.\n",
    "\n",
    "    Reads the module-level sessions list. Tabs inside each merged message\n",
    "    are replaced with spaces, the message is wrapped in Sentence with a\n",
    "    JiebaSeg segmenter, and the items of its cuted_sentence attribute are\n",
    "    joined with single spaces. The resulting lines are written to\n",
    "    conf.file_texts_for_embedding.\n",
    "\n",
    "    Returns:\n",
    "        list of str: one space-joined string per merged message.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if conf.file_stopwords does not exist.\n",
    "    \"\"\"\n",
    "    # Fail fast with a clear message. The segmenter used to be created only\n",
    "    # when the stop-word file existed, so a missing file surfaced as a\n",
    "    # NameError on jb_seg after all texts had already been collected.\n",
    "    if not os.path.exists(conf.file_stopwords):\n",
    "        raise FileNotFoundError(\"stopwords file not found: %s\" % conf.file_stopwords)\n",
    "    jb_seg = JiebaSeg(file_stopwords=conf.file_stopwords)\n",
    "\n",
    "    texts = []\n",
    "    for sess in tqdm(sessions, ncols=100, desc=\"create texts for embedding\"):\n",
    "        texts.extend(x[1].replace('\\t', \" \") for x in sess.qas_merged)\n",
    "\n",
    "    sentences = []\n",
    "    for s in tqdm(texts, ncols=100, desc=\"cut sentence\"):\n",
    "        s_cuted = Sentence(s, seg=jb_seg)\n",
    "        sentences.append(\" \".join(s_cuted.cuted_sentence))\n",
    "\n",
    "    write_file(file=conf.file_texts_for_embedding, content=sentences, mode='w', encoding='utf-8')\n",
    "    return sentences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sentences = create_sentences_for_embedding()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4 - 单轮训练集构建\n",
    "\n",
    "---\n",
    "\n",
    "1. 构造QAQAQ+Q形式的训练集\n",
    "2. 仅使用Q来进行匹配"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5 - 多轮训练集构建\n",
    "\n",
    "----\n",
    "\n",
    "1. 字段列表：session id, question, answer"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python3.6",
   "language": "python",
   "name": "python3.6"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
